# -*- coding: utf-8 -*-  
'''
bert相关预处理

@author: luoyi
Created on 2021年4月16日
'''
import tensorflow as tf


#    根据词编码求掩码
def padding_mask(inputs):
    '''根据词编码求掩码
        @param inputs: Tensor(batch_size, max_sen_len)    [PAD]为0
        @return: Tensor(batch_size, max_sen_len)    有词的地方为0，没词的地方为1（方便后面*1e-9）
    '''
    mask = tf.cast(tf.equal(inputs, 0), dtype=tf.float32)
    #    扩展为[batch_size, 1, 1, sentence_maxlen]，用的时候自行广播
    mask = mask[:, tf.newaxis, tf.newaxis, :]
    return mask
